from turtle import color
import torch
import torch.nn as nn
import numpy as np
import pandas as pd
import matplotlib.pyplot as plt

import seaborn as sns
from sklearn.preprocessing import MinMaxScaler

import time

from rlData_Model import LSTM #self define

device=torch.device("cuda" if torch.cuda.is_available() else "cpu")
print(f"Using device:{device}")

file_path="/mnt/home/linxy/LSTM_Learning/rlData.csv"
data=pd.read_csv(file_path)

data=data.sort_values("Date")

print(data.head())

print(data.shape)

plt.figure(figsize=(15,9))
plt.plot(data[["Close"]])
plt.xticks(range(0,data.shape[0],20),data["Date"].loc[::20],rotation=45)
plt.title("*****Stock Price",fontsize=18,fontweight="bold")
plt.xlabel("Date",fontsize=18)
plt.ylabel('Close Price (USD)', fontsize=18)
plt.savefig('/mnt/home/linxy/LSTM_Learning/StockPrice.jpg')
plt.show()

price=data[["Close"]]
print(price.info())

scaler=MinMaxScaler(feature_range=(-1,1))
price["Close"]=scaler.fit_transform(price["Close"].values.reshape(-1,1))
print(price["Close"].shape)

def split_data(stock,lookback):
    data_raw=stock.to_numpy()
    data=[]

    for index in range(len(data_raw)-lookback):
        data.append(data_raw[index:index+lookback])

    data=np.array(data)
    print(type(data))

    test_set_size=int(np.round(0.2*data.shape[0]))
    train_set_size=data.shape[0]-test_set_size

    x_train=data[:train_set_size,:-1,:]
    y_train=data[:train_set_size,-1,:]

    x_test=data[train_set_size:,:-1,:]
    y_test=data[train_set_size:,-1,:]

    return [x_train,y_train,x_test,y_test]

lookback=20
x_train, y_train, x_test, y_test = split_data(price, lookback)
print('x_train.shape = ', x_train.shape)
print('y_train.shape = ', y_train.shape)
print('x_test.shape = ', x_test.shape)
print('y_test.shape = ', y_test.shape)

x_train=torch.from_numpy(x_train).type(torch.Tensor).to(device)
x_test=torch.from_numpy(x_test).type(torch.Tensor)
y_train_lstm = torch.from_numpy(y_train).type(torch.Tensor).to(device)
y_test_lstm = torch.from_numpy(y_test).type(torch.Tensor)
y_train_gru = torch.from_numpy(y_train).type(torch.Tensor)
y_test_gru = torch.from_numpy(y_test).type(torch.Tensor)

input_dim=1
hidden_dim=32
num_layers=2
output_dim=1
num_epochs=100

model=LSTM(input_dim=input_dim,hidden_dim=hidden_dim,output_dim=output_dim,num_layers=num_layers)
model=model.to(device)

criterion=nn.MSELoss()
optimiser=torch.optim.Adam(model.parameters(),lr=0.01)

hist=np.zeros(num_epochs)
start_time=time.time()
lstm=[]

for t in range(num_epochs):
    y_train_pred=model(x_train)

    loss=criterion(y_train_pred,y_train_lstm)
    print("Epoch ",t,"MSE:",loss.item())
    hist[t]=loss.item()
    optimiser.zero_grad()
    loss.backward()
    optimiser.step()

training_time=time.time()-start_time
print("Training time:{}".format(training_time))

predict = pd.DataFrame(scaler.inverse_transform(y_train_pred.detach().cpu().numpy()))
print(predict)
original=pd.DataFrame(scaler.inverse_transform(y_train_lstm.detach().cpu().numpy()))
print(original)

sns.set_style("darkgrid")
fig=plt.figure()
fig.subplots_adjust(wspace=0.2,hspace=0.2)

plt.subplot(1,2,1)
ax=sns.lineplot(x=original.index,y=original[0],label="Data",color="royalblue")
ax=sns.lineplot(x=predict.index,y=predict[0],label="Training Prediction (LSTM)",color="tomato")
print(predict.index)
print("aaaaa")
print(predict[0])

ax.set_title("Stock price",size=14,fontweight="bold")
ax.set_xlabel("Days", size = 14)
ax.set_ylabel("Cost (USD)", size = 14)
ax.set_xticklabels('', size=10)

plt.subplot(1, 2, 2)
ax = sns.lineplot(data=hist, color='royalblue')
ax.set_xlabel("Epoch", size = 14)
ax.set_ylabel("Loss", size = 14)
ax.set_title("Training Loss", size = 14, fontweight='bold')
fig.set_figheight(6)
fig.set_figwidth(16)
plt.savefig('/mnt/home/linxy/LSTM_Learning/figure.png')

import math, time
from sklearn.metrics import mean_squared_error

# make predictions
y_test_pred = model(x_test)

# invert predictions
y_train_pred = scaler.inverse_transform(y_train_pred.detach().numpy())
y_train = scaler.inverse_transform(y_train_lstm.detach().numpy())
y_test_pred = scaler.inverse_transform(y_test_pred.detach().numpy())
y_test = scaler.inverse_transform(y_test_lstm.detach().numpy())

# calculate root mean squared error
trainScore = math.sqrt(mean_squared_error(y_train[:,0], y_train_pred[:,0]))
print('Train Score: %.2f RMSE' % (trainScore))
testScore = math.sqrt(mean_squared_error(y_test[:,0], y_test_pred[:,0]))
print('Test Score: %.2f RMSE' % (testScore))
lstm.append(trainScore)
lstm.append(testScore)
lstm.append(training_time)
# shift train predictions for plotting
trainPredictPlot = np.empty_like(price)
trainPredictPlot[:, :] = np.nan
trainPredictPlot[lookback:len(y_train_pred)+lookback, :] = y_train_pred

# shift test predictions for plotting
testPredictPlot = np.empty_like(price)
testPredictPlot[:, :] = np.nan
testPredictPlot[len(y_train_pred)+lookback-1:len(price)-1, :] = y_test_pred

original = scaler.inverse_transform(price['Close'].values.reshape(-1,1))

predictions = np.append(trainPredictPlot, testPredictPlot, axis=1)
predictions = np.append(predictions, original, axis=1)
result = pd.DataFrame(predictions)

import plotly.express as px
import plotly.graph_objects as go

fig = go.Figure()
fig.add_trace(go.Scatter(go.Scatter(x=result.index, y=result[0],
                    mode='lines',
                    name='Train prediction')))
fig.add_trace(go.Scatter(x=result.index, y=result[1],
                    mode='lines',
                    name='Test prediction'))
fig.add_trace(go.Scatter(go.Scatter(x=result.index, y=result[2],
                    mode='lines',
                    name='Actual Value')))
fig.update_layout(
    xaxis=dict(
        showline=True,
        showgrid=True,
        showticklabels=False,
        linecolor='white',
        linewidth=2
    ),
    yaxis=dict(
        title_text='Close (USD)',
        titlefont=dict(
            family='Rockwell',
            size=12,
            color='white',
        ),
        showline=True,
        showgrid=True,
        showticklabels=True,
        linecolor='white',
        linewidth=2,
        ticks='outside',
        tickfont=dict(
            family='Rockwell',
            size=12,
            color='white',
        ),
    ),
    showlegend=True,
    template = 'plotly_dark'

)



annotations = []
annotations.append(dict(xref='paper', yref='paper', x=0.0, y=1.05,
                              xanchor='left', yanchor='bottom',
                              text='Results (LSTM)',
                              font=dict(family='Rockwell',
                                        size=26,
                                        color='white'),
                              showarrow=False))
fig.update_layout(annotations=annotations)

fig.show()
